import numpy as np
import math
import random

from ase.io import read
from ase import Atoms

from UnitCell_Environment.unitcell_environment.env.agent import Agent
from UnitCell_Environment.unitcell_environment.env.utils import COMP, random_airss_structure
import datetime

from chgnet.model.model import CHGNet
from pymatgen.io.cif import *
from pymatgen.io.ase import AseAtomsAdaptor


class World:  # multi-agent world
    """Multi-agent environment: a periodic structure with one Agent per atom.

    Total energy, per-atom forces and per-site energies are obtained from a
    CHGNet model (``self.chgnet``).  Each Agent mirrors the state of the atom
    at index ``agent.id`` in ``self.atoms``.
    """

    # Per-agent attributes snapshotted by save_checkpoint() and restored by
    # revert_to_checkpoint().
    _CHECKPOINT_FIELDS = ('energy', 'gradient', 'prev_gradient', 'dgrad',
                          'prev_dgrad', 'gnorm', 'prev_energy', 'prev_gnorm')

    def __init__(self, config):
        """Configure the world from ``config``, load CHGNet and build a structure.

        Args:
            config: mapping of options.  ``"comp"`` is required: a
                comma-separated list of composition names looked up in
                ``COMP``.  Every other key is optional and falls back to the
                default shown below.  ``"gpu"`` selects the CHGNet device and
                ``"cif"`` an optional starting-structure file.
        """
        self.debug = config.get("debug", False)
        self.neighbors_limit = config.get("neighbors_limit", 10)
        self.comp_list = config["comp"].split(",")
        self.comp_name = self.comp_list[0]
        self.cell_fixed = config.get("cell_fixed", False)

        # Feature toggles.
        self.use_energy_feature = config.get("use_energy_feature", False)
        self.use_log_gnorm_feature = config.get("use_log_gnorm_feature", True)
        self.use_gnorm_feature = config.get("use_gnorm_feature", False)
        self.use_grad_feature = config.get("use_grad_feature", True)
        self.use_d_grad_feature = config.get("use_d_grad_feature", False)
        self.use_dd_grad_feature = config.get("use_dd_grad_feature", False)
        self.use_last_step_feature = config.get("use_last_step_feature", False)
        self.use_2ndlast_step_feature = config.get("use_2ndlast_step_feature", False)
        self.use_log_fmax_feature = config.get("use_log_fmax_feature", False)
        self.use_mass_feature = config.get("use_mass_feature", False)
        self.use_cell_feature = config.get("use_cell_feature", False)

        # Largest allowed absolute gradient component (see normalize_grad).
        self.max_grad_value = config.get("max_grad_value", 500)
        self.variable_step_size = config.get("variable_step_size", None)
        self.varss_c4 = config.get("varss_c4", 0.5)
        # Floor applied to force norms before taking log2.
        self.min_gnorm = config.get("min_gnorm", 0.0001)

        self.chgnet = CHGNet.load(use_device="cuda" if config.get("gpu", False) else "cpu")

        self.agents = {}  # agent name -> Agent
        self.env_distance = 0
        self.energy = 0
        self.min_energy = 0
        self.fmax = 0
        self.fmean = 0

        self.min_atoms = None
        # Accumulated seconds spent inside CHGNet predictions.
        self.time = 0

        self.initialize(cif=config.get("cif", None))

    def initialize(self, seed=None, cif=None, update_atoms=True, comp=None):
        """Reset per-episode bookkeeping and optionally build a new structure.

        Args:
            seed: forwarded to create_starting_structure().
            cif: optional structure file for ``ase.io.read``.  When None, a
                random AIRSS structure is generated instead.
            update_atoms: when False, keep the current ``self.atoms`` and
                ``self.comp`` and only recompute energies and forces.
            comp: composition name to use.  When None, one is chosen at random
                from ``self.comp_list``.
        """
        self.energy = 0
        self.test_energy = 0
        self.fmax = 0
        self.fmean = 0

        self.atoms_history = []

        if update_atoms:
            self.comp_name = comp if comp is not None else random.choice(self.comp_list)
            self.comp = COMP[self.comp_name]
            self.create_starting_structure(seed=seed, cif=cif)

        self.update_energy()

        # The starting structure is the best one seen so far in this episode.
        self.min_atoms = self.atoms.copy()
        self.min_energy = self.energy

        # The relaxed reference energy (self.relaxed_energy()) is deliberately
        # not computed, to speed up learning.
        self.test_energy = 0

    def create_starting_structure(self, seed=None, cif=None):
        """Build ``self.atoms`` and one Agent per atom requested by ``self.comp``.

        Positions are taken from the structure read from ``cif`` or, if
        ``cif`` is None, from ``random_airss_structure``.  Source positions
        are grouped by species in the key order of ``self.comp``.  Atoms are
        then created in that same order, so atom index == agent id.

        Args:
            seed: currently unused here.
                NOTE(review): confirm whether it should seed random_airss_structure.
            cif: optional path of a structure file.

        Raises:
            ValueError: if the source structure has fewer atoms of the listed
                species than ``self.comp`` requires.
        """
        if cif is not None:
            source = read(cif)
        else:
            source = random_airss_structure(comp=self.comp_name, cell_fixed=self.cell_fixed)

        self.atoms = Atoms(cell=source.cell, pbc=[True, True, True])
        self.cell = source.cell

        species = list(self.comp.keys())

        # Compute the symbols once instead of once per species.
        symbols = [source[i].symbol for i in range(len(source))]
        positions = [source.positions[i]
                     for a_name in species
                     for i, symbol in enumerate(symbols)
                     if symbol == a_name]

        # The first entry of each composition record is used as the atom count.
        required = sum(self.comp[a_name][0] for a_name in species)
        if len(positions) < required:
            raise ValueError(
                f"Starting structure provides {len(positions)} positions for species "
                f"{species}, but composition {self.comp_name} requires {required}")

        self.names_by_indices = {}
        self.agents = {}

        # Create the agents.  Positions are consumed sequentially, so the next
        # free position index equals the current number of atoms.
        for a_name in species:
            for i in range(self.comp[a_name][0]):
                self.atoms += Atoms(a_name, positions=[positions[len(self.atoms)]],
                                    cell=self.atoms.cell, pbc=self.atoms.pbc)

                agent = Agent(id=len(self.atoms) - 1, name=f"{a_name}_{i}", agents=self.agents, world=self)

                self.agents[agent.name] = agent
                self.names_by_indices[agent.id] = agent.name

    def features(self):
        """Return the global (structure-level) feature vector as a list.

        Depending on the toggles, it contains, in this order:
        log2 of the largest force norm (floored at ``min_gnorm``), the total
        energy, and the nine components of the three cell vectors.
        """
        log_fmax_f = [math.log2(max(self.min_gnorm, self.fmax))] if self.use_log_fmax_feature else []

        e_f = [self.energy] if self.use_energy_feature else []

        cell_f = list(np.concatenate((self.cell[0], self.cell[1], self.cell[2]))) if self.use_cell_feature else []

        return log_fmax_f + e_f + cell_f

    def take_action(self, action_type, actions):
        """Apply one step: move every atom by its agent's action, then re-evaluate energy.

        Args:
            action_type: only used in the debug message.
            actions: mapping agent name -> displacement vector.  It must
                contain every agent.

        Returns:
            Always True.
        """
        self.atoms_history.append(self.atoms.copy())
        if self.debug:
            self.print_debug(f"Action: {action_type}, {actions.values()}, agents: {actions.keys()}")

        # self.agents preserves creation order, which matches atom indices.
        steps = [actions[agent.name] for agent in self.agents.values()]
        self.parallel_moves(steps)

        self.update_energy()

        return True

    def move_vectors(self, step_size):
        """Return 6 random 3-D displacement vectors, each of length ``step_size``."""
        res = []

        for _ in range(6):
            vector = [random.uniform(-1, 1) for _ in range(3)]
            length = math.sqrt(sum(x**2 for x in vector))
            res.append([x * step_size / length for x in vector])

        return res

    def save_checkpoint(self):
        """Snapshot the structure, total energy and per-agent state."""
        self.checkpoint = {'atoms': self.atoms.copy(),
                           'energy': self.energy}

        for name, agent in self.agents.items():
            self.checkpoint[name] = {field: getattr(agent, field)
                                     for field in self._CHECKPOINT_FIELDS}

    def revert_to_checkpoint(self):
        """Restore the state captured by the most recent save_checkpoint().

        The stored structure is copied rather than aliased.  Later in-place
        position updates on ``self.atoms`` therefore cannot corrupt the
        checkpoint, which may be reverted to again.

        NOTE(review): fmax/fmean, ``self.gradients`` and agents' last_step are
        not part of the checkpoint and are left unchanged.
        """
        self.atoms = self.checkpoint['atoms'].copy()
        self.energy = self.checkpoint['energy']

        for name, agent in self.agents.items():
            saved = self.checkpoint[name]
            for field in self._CHECKPOINT_FIELDS:
                setattr(agent, field, saved[field])

    def update_positions(self, atom_ids, positions):
        """Overwrite the Cartesian positions of the given atoms in ``self.atoms``.

        Args:
            atom_ids: a single integer index, or an iterable of indices.
            positions: a single position (1-D) or one row per index (2-D).
                Any array-like (including nested lists) is accepted.

        Raises:
            IndexError: if an index is beyond ``self.atoms``, or if there are
                fewer positions than indices.
        """
        atom_ids = [atom_ids] if isinstance(atom_ids, (int, np.integer)) else atom_ids
        positions = np.atleast_2d(np.asarray(positions))

        n_atoms = len(self.atoms.positions)
        for i, atom_id in enumerate(atom_ids):
            if atom_id >= n_atoms:
                raise IndexError(f"Error: id {atom_id} is not in positions of length {n_atoms}")
            if i >= len(positions):
                raise IndexError(f"Error: i {i} is not in positions of length {len(positions)}")
            self.atoms.positions[atom_id] = positions[i]

    def log_gnorms(self):
        """Return log2 of each agent's force norm, floored at ``min_gnorm``."""
        return [math.log2(max(a.gnorm, self.min_gnorm)) for a in self.agents.values()]

    def _advance_agents(self, gradients, site_energies, normalize=False):
        """Shift each agent's current state into prev_* and load the new values.

        ``gradients[agent.id]`` becomes the agent's gradient.  When
        ``normalize`` is true, that row is first capped in place via
        normalize_grad().  ``site_energies[agent.id]`` becomes the agent's energy.
        """
        for agent in self.agents.values():
            if normalize:
                gradients[agent.id] = self.normalize_grad(gradients[agent.id])
            agent.prev_gradient = agent.gradient
            agent.gradient = gradients[agent.id]
            agent.prev_dgrad = agent.dgrad
            agent.dgrad = agent.gradient - agent.prev_gradient
            agent.prev_gnorm = agent.gnorm
            agent.prev_energy = agent.energy
            agent.gnorm = np.linalg.norm(agent.gradient)
            agent.energy = site_energies[agent.id]

    def update_from_data(self, data):
        """Load externally computed results instead of calling CHGNet.

        Args:
            data: mapping with keys ``"energy"``, ``"forces"``,
                ``"site_energies"`` and ``"atom_positions"``.  Forces and site
                energies are indexed by atom id.

        NOTE(review): unlike update_energy_chgnet(), these forces are not passed
        through normalize_grad().  Confirm the data is already normalised.
        """
        self.atoms_history.append(self.atoms.copy())

        self.save_checkpoint()

        self.energy = data["energy"]
        self.gradients = data["forces"]
        self._advance_agents(self.gradients, data['site_energies'])

        self.update_positions(range(len(self.atoms)), data["atom_positions"])

        self.update_min_mean_max()

    def update_energy(self):
        """Recompute energy and forces for all agents at their current positions."""
        self.update_energy_chgnet()
        self.update_min_mean_max()

    def update_energy_chgnet(self):
        """Evaluate ``self.atoms`` with CHGNet and refresh every agent's state.

        The total energy goes to ``self.energy``.  The normalised forces
        (``result['f']``) go to ``self.gradients``.  Elapsed prediction time
        is added to ``self.time`` in seconds.
        """
        # The module-level name ``datetime`` is bound to the datetime *module*,
        # which has no ``now()``.  The monotonic perf_counter is also the
        # appropriate clock for measuring elapsed time.
        import time

        structure = AseAtomsAdaptor.get_structure(self.atoms)

        start = time.perf_counter()
        result = self.chgnet.predict_structure(structure, task='ef', return_site_energies=True)
        self.time += time.perf_counter() - start

        self.energy = float(result['e'])
        self.gradients = result['f']

        if self.debug:  # avoid formatting arrays when debug output is off
            self.print_debug(f"ChgNET Energy: {self.energy}")
            self.print_debug(f"ChgNET Forces: {self.gradients}")

        self._advance_agents(self.gradients, result['site_energies'], normalize=True)

    def parallel_moves(self, move_vectors):
        """Displace all atoms at once, wrapping the results back into the unit cell.

        Args:
            move_vectors: one displacement per atom, in atom-index order.

        A checkpoint is saved before positions change.  Each agent records its
        step in ``last_step`` and the previous one in ``prev_last_step``.
        """
        start_pos = self.atoms.positions.copy()
        pos = self.positions_in_cell(start_pos + move_vectors)

        self.save_checkpoint()

        self.update_positions(range(len(move_vectors)), pos)

        for i, step in enumerate(move_vectors):
            agent = self.agents[self.names_by_indices[i]]
            agent.prev_last_step = agent.last_step
            agent.last_step = step

        if self.debug:
            self.print_debug(f"Update pos on {move_vectors}, start_pos: {start_pos}, new pos: {pos}, pos from ase: {self.atoms.positions}")

    def print_debug(self, text):
        """Print ``text`` only when debug output is enabled."""
        if self.debug:
            print(text)

    def positions_in_cell(self, positions):
        """Wrap an (N, 3) array of Cartesian positions into the unit cell.

        Positions are converted to fractional coordinates w.r.t. the cell
        vectors ``self.cell[0..2]``, reduced modulo 1, and converted back.
        """
        # Columns of cell_matrix are the cell vectors.
        cell_matrix = np.vstack([self.cell[0], self.cell[1], self.cell[2]]).T

        frac_positions = np.linalg.solve(cell_matrix, positions.T).T
        frac_positions %= 1.0

        return frac_positions @ cell_matrix.T

    def position_in_cell(self, pos):
        """Wrap a single Cartesian position into the unit cell.

        Returns:
            (offset, new_pos).  ``offset`` is the integer cell translation
            (floor of the fractional coordinates, shape (3, 1)).  ``new_pos``
            is the wrapped Cartesian position.
        """
        A = np.column_stack((self.cell[0], self.cell[1], self.cell[2]))

        # Fractional coordinates of pos w.r.t. the cell vectors.
        coefficients = np.linalg.solve(A, pos.reshape(3, 1))

        new_pos = (coefficients[0] % 1) * np.asarray(self.cell[0]) \
            + (coefficients[1] % 1) * np.asarray(self.cell[1]) \
            + (coefficients[2] % 1) * np.asarray(self.cell[2])

        offset = np.floor(coefficients).astype(int)

        return offset, new_pos

    def update_min_mean_max(self):
        """Refresh the max/mean force norms and track the lowest-energy structure."""
        gnorms = [agent.gnorm for agent in self.agents.values()]
        self.fmax = max(gnorms)
        self.fmean = np.mean(gnorms)

        if self.min_energy > self.energy:
            self.min_energy = self.energy
            self.min_atoms = self.atoms.copy()

    def normalize_grad(self, grad):
        """Scale ``grad`` so that no component exceeds ``max_grad_value`` in magnitude.

        The direction is preserved.  Gradients already within the limit are
        returned unchanged (same object).
        """
        max_abs_v = np.max(np.abs(grad))

        if max_abs_v > self.max_grad_value:
            return grad / max_abs_v * self.max_grad_value

        return grad

    def variable_step_size_coef(self, agent):
        """Step-size multiplier for ``agent``.

        With ``variable_step_size == "gnorm"``, this is the agent's force norm
        capped at ``varss_c4``.  Otherwise it is 1.
        """
        if self.variable_step_size == "gnorm":
            return min(agent.gnorm, self.varss_c4)

        return 1
           